{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#| hide\n",
    "#| eval: false\n",
    "! [ -e /content ] && pip install -Uqq fastai  # upgrade fastai on colab"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#| default_exp callback.preds"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#| export\n",
    "from __future__ import annotations\n",
    "from fastai.basics import *"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#| hide\n",
    "from nbdev.showdoc import *\n",
    "from fastai.test_utils import *"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Predictions callbacks\n",
    "\n",
    "> Various callbacks to customize the behavior of `get_preds`"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## MCDropoutCallback\n",
    "\n",
    "> Turns on dropout during inference, allowing you to call `Learner.get_preds` multiple times to approximate your model's uncertainty using [Monte Carlo Dropout](https://arxiv.org/pdf/1506.02142.pdf)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#| export\n",
    "class MCDropoutCallback(Callback):\n",
    "    \"Keep dropout layers in training mode during validation so repeated `Learner.get_preds` calls give stochastic predictions\"\n",
    "    def before_validate(self):\n",
    "        \"Put every layer whose class name contains 'dropout' into training mode\"\n",
    "        for layer in flatten_model(self.model):\n",
    "            if 'dropout' in layer.__class__.__name__.lower(): layer.train()\n",
    "\n",
    "    def after_validate(self):\n",
    "        \"Set those same layers to eval mode\"\n",
    "        for layer in flatten_model(self.model):\n",
    "            if 'dropout' in layer.__class__.__name__.lower(): layer.eval()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "torch.Size([10, 32, 1])"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "learn = synth_learner()\n",
    "\n",
    "# Call get_preds several times with MC dropout enabled, keeping only the predictions from each call\n",
    "n_samples = 10\n",
    "dist_preds = [learn.get_preds(cbs=[MCDropoutCallback()])[0] for _ in range(n_samples)]\n",
    "\n",
    "# Stacking yields a tensor of shape [# of MC samples, # of items predicted, ...]\n",
    "torch.stack(dist_preds).shape"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Export -"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#| hide\n",
    "from nbdev import nbdev_export\n",
    "nbdev_export()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "jupytext": {
   "split_at_heading": true
  },
  "kernelspec": {
   "display_name": "python3",
   "language": "python",
   "name": "python3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
